package dbtool

import (
	"errors"
	"fmt"
)

// Operators accepted in ConditionItem.Op. The trailing comment on each gives
// the SQL it is rendered to by ConditionItem.QuerySql.
const (
	OpEqual            = "eq"  // equal: "="
	OpNotEqual         = "ne"  // not equal: "<>"
	OpLessThan         = "lt"  // less than: "<"
	OpLessThanEqual    = "lte" // less than or equal: "<="
	OpGreaterThan      = "gt"  // greater than: ">"
	OpGreaterThanEqual = "gte" // greater than or equal: ">="
	OpIn               = "in"  // membership: "IN (?)"
	OpLike             = "lk"  // pattern match: "LIKE"
)

// Connector values intended for Condition.Conn, which joins sibling
// conditions in the rendered SQL.
const (
	ConnAnd = "and" // logical AND
	ConnOr  = "or"  // logical OR
)

// ConditionItem is an atomic condition: one left operand compared with one
// right operand using a single operator.
type ConditionItem struct {
	Left  string      `json:"left"`  // left operand, usually a column name
	Right interface{} `json:"right"` // right operand, usually the value compared against
	Op    string      `json:"op"`    // operator: one of the Op* constants
}

// QuerySql renders the item as a SQL fragment containing a single "?"
// placeholder and returns Right as the argument to bind to it.
//
// Op is mapped as: eq "=", ne "<>", lt "<", lte "<=", gt ">", gte ">=",
// lk "LIKE", and in "IN (?)". Any other Op yields an error; in that case
// whereSql is "" and whereArgs is nil (no partially populated result).
//
// Left is written into the SQL text verbatim (only Right is bound as an
// argument), so it must never come unchecked from untrusted input.
// NOTE(review): presumably callers whitelist column names, e.g. via
// Condition.Fields — confirm.
func (m ConditionItem) QuerySql() (whereSql string, whereArgs interface{}, err error) {
	var sqlOp string
	switch m.Op {
	case OpEqual:
		sqlOp = "="
	case OpNotEqual:
		sqlOp = "<>"
	case OpLessThan:
		sqlOp = "<"
	case OpLessThanEqual:
		sqlOp = "<="
	case OpGreaterThan:
		sqlOp = ">"
	case OpGreaterThanEqual:
		sqlOp = ">="
	case OpIn:
		// The placeholder is parenthesised here; NOTE(review): presumably the
		// driver/ORM expands a slice Right into a value list — confirm.
		return fmt.Sprintf("%s IN (?)", m.Left), m.Right, nil
	case OpLike:
		sqlOp = "LIKE"
	default:
		return "", nil, errors.New("目前不支持" + m.Op + "操作符")
	}
	return fmt.Sprintf("%s %s ?", m.Left, sqlOp), m.Right, nil
}

// Condition is a composite condition: its Items and non-empty Subs are each
// rendered in parentheses and joined with Conn.
type Condition struct {
	Subs  []Condition     `json:"subs"`  // nested composite conditions
	Items []ConditionItem `json:"items"` // atomic conditions at this level
	Conn  string          `json:"conn"`  // connector joining the parts: ConnAnd or ConnOr
}

// Fields returns the distinct left operands (ConditionItem.Left, usually
// column names) referenced anywhere in the condition tree, or nil if there
// are none.
//
// Names appear in first-seen order: this level's Items first, then each Sub
// depth-first, so the result is deterministic (unlike ranging over a map).
func (m Condition) Fields() (result []string) {
	return m.appendFields(make(map[string]bool), nil)
}

// appendFields appends to out every Left in the tree not already present in
// seen, marks it as seen, and returns the extended slice.
func (m Condition) appendFields(seen map[string]bool, out []string) []string {
	for _, item := range m.Items {
		if seen[item.Left] {
			continue
		}
		seen[item.Left] = true
		out = append(out, item.Left)
	}
	for _, sub := range m.Subs {
		out = sub.appendFields(seen, out)
	}
	return out
}

// fields sets (*result)[name] = true for every Left in the tree; *result must
// be a non-nil map if any names exist.
// NOTE(review): unused within this file; the signature is kept in case other
// files in the package call it.
func (m Condition) fields(result *map[string]bool) {
	for _, name := range m.Fields() {
		(*result)[name] = true
	}
}

// QuerySql renders the condition tree as a SQL WHERE fragment with "?"
// placeholders, together with the placeholder arguments in order.
//
// Each Item and each non-empty Sub is wrapped in parentheses and the parts are
// joined with " <Conn> ", e.g. "(a = ?) and (b > ?)". A condition with no
// parts renders as "" with nil args. Empty Subs are skipped, because wrapping
// them would emit "()", which is invalid SQL.
//
// Conn is written into the SQL text verbatim, so when it is needed (two or
// more parts) it must be ConnAnd, ConnOr, or their upper-case forms. Anything
// else is rejected to prevent SQL injection through Conn. On any error,
// whereSql is "" and whereArgs is nil.
func (m Condition) QuerySql() (whereSql string, whereArgs []interface{}, err error) {
	parts := make([]string, 0, len(m.Items)+len(m.Subs))
	var args []interface{}

	for _, item := range m.Items {
		itemSql, itemArg, itemErr := item.QuerySql()
		if itemErr != nil {
			return "", nil, itemErr
		}
		parts = append(parts, itemSql)
		args = append(args, itemArg)
	}
	for _, sub := range m.Subs {
		subSql, subArgs, subErr := sub.QuerySql()
		if subErr != nil {
			return "", nil, subErr
		}
		if subSql == "" {
			// An empty sub-condition has neither SQL nor args; skip it.
			continue
		}
		parts = append(parts, subSql)
		args = append(args, subArgs...)
	}

	// Conn only reaches the SQL when there is something to join.
	if len(parts) > 1 {
		switch m.Conn {
		case ConnAnd, ConnOr, "AND", "OR":
		default:
			return "", nil, fmt.Errorf("目前不支持%q连接符", m.Conn)
		}
	}

	// Append into one buffer instead of re-formatting the whole string per part.
	var buf []byte
	for i, part := range parts {
		if i > 0 {
			buf = append(buf, ' ')
			buf = append(buf, m.Conn...)
			buf = append(buf, ' ')
		}
		buf = append(buf, '(')
		buf = append(buf, part...)
		buf = append(buf, ')')
	}
	return string(buf), args, nil
}
